import sys
from pathlib import Path

# Make the project root importable so this test module can be run directly
# from its own directory. parents[0] is the folder containing this file, so
# parents[6] is six directories above that folder.
_this_file = Path(__file__).resolve()
project_root = str(_this_file.parents[6])
if project_root not in sys.path:
    sys.path.insert(0, project_root)
del _this_file

#
# Copyright © 2024 Agora
# This file is part of TEN Framework, an open source project.
# Licensed under the Apache License, Version 2.0, with certain conditions.
# Refer to the "LICENSE" file in the root directory for more information.
#
from pathlib import Path
import json
from typing import Any
from unittest.mock import patch, AsyncMock, MagicMock
import tempfile
import os
import asyncio
import filecmp
import shutil
import threading
import base64

from ten_runtime import (
    ExtensionTester,
    TenEnvTester,
    Cmd,
    CmdResult,
    StatusCode,
    Data,
    TenError,
)
from ten_ai_base.struct import TTSTextInput, TTSFlush, TTS2HttpResponseEventType


# ================ test dump file functionality ================
class ExtensionTesterDump(ExtensionTester):
    """Extension tester that records the audio produced for one TTS request.

    It sends a single ``tts_text_input``, stores the bytes of every audio
    frame it receives and stops the test on ``tts_audio_end``. After the run,
    the collected audio can be written to a file of the test's own and
    compared byte-for-byte with the dump file written by the TTS extension.
    """

    def __init__(self):
        super().__init__()
        # Fixed dump directory; it should be the same directory that is
        # passed to the extension as "dump_path".
        self.dump_dir = "./dump/"
        # The test's own file uses a distinct name so it never collides with
        # (and is never mistaken for) the file generated by the extension.
        self.test_dump_file_path = os.path.join(
            self.dump_dir, "test_hume_manual_dump.pcm"
        )
        self.audio_end_received = False
        # Raw bytes of every received audio frame, in arrival order.
        self.received_audio_chunks = []

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Called when test starts, sends a TTS request."""
        ten_env_tester.log_info("Dump test started, sending TTS request.")

        tts_input = TTSTextInput(
            request_id="tts_request_dump",
            text="hello world, testing audio dump functionality",
            text_input_end=True,
        )
        data = Data.create("tts_text_input")
        data.set_property_from_json(None, tts_input.model_dump_json())
        ten_env_tester.send_data(data)
        ten_env_tester.on_start_done()

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Stop the test once the extension reports ``tts_audio_end``."""
        name = data.get_name()
        if name == "tts_audio_end":
            ten_env.log_info("Received tts_audio_end, stopping test.")
            self.audio_end_received = True
            ten_env.stop_test()

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Receives audio frames and collects their data."""
        buf = audio_frame.lock_buf()
        try:
            # Copy the bytes so they stay usable after the buffer is unlocked
            # (presumably the locked view is only valid until unlock_buf).
            self.received_audio_chunks.append(bytes(buf))
        finally:
            audio_frame.unlock_buf(buf)

    def write_test_dump_file(self):
        """Write all collected audio chunks, in arrival order, to the test's own file."""
        os.makedirs(self.dump_dir, exist_ok=True)
        with open(self.test_dump_file_path, "wb") as f:
            f.writelines(self.received_audio_chunks)

    def find_tts_dump_file(self) -> str | None:
        """Return the path of the ``.pcm`` dump written by the TTS extension.

        The test's own file (``test_dump_file_path``) lives in the same
        directory and is explicitly excluded. Otherwise the later comparison
        would compare that file with itself and always pass. If several
        candidates exist, the lexicographically first one is returned so the
        choice is deterministic.

        Returns:
            The extension's dump file path, or None when the dump directory
            does not exist or contains no other ``.pcm`` file.
        """
        if not os.path.isdir(self.dump_dir):
            return None
        own_name = os.path.basename(self.test_dump_file_path)
        for filename in sorted(os.listdir(self.dump_dir)):
            if not filename.endswith(".pcm") or filename == own_name:
                continue
            candidate = os.path.join(self.dump_dir, filename)
            if os.path.isfile(candidate):
                return candidate
        return None


@patch("humeai_tts_python.humeTTS.AsyncHumeClient")
def test_dump_functionality(MockHumeClient):
    """Tests that the dump file from the TTS extension matches the audio received by the test extension.

    The Hume client is mocked to stream two known audio chunks. The test
    collects the audio frames the extension forwards, writes them to its own
    file and compares that file byte-for-byte with the extension's dump file.
    The dump directory is removed afterwards, even if the test fails.
    """
    print("Starting test_dump_functionality with mock...")

    # --- Directory Setup ---
    # Fixed './dump/' directory; it matches ExtensionTesterDump.dump_dir and
    # is passed to the extension as "dump_path" below.
    DUMP_PATH = "./dump/"

    # Start from an empty directory in case a previous run left files behind.
    if os.path.exists(DUMP_PATH):
        shutil.rmtree(DUMP_PATH)
    os.makedirs(DUMP_PATH)

    try:
        # --- Mock Configuration ---
        mock_client = MockHumeClient.return_value
        mock_client.clean = AsyncMock()

        # Fake audio data to be streamed.
        fake_audio_chunk_1 = b"\x11\x22\x33\x44" * 20  # 80 bytes
        fake_audio_chunk_2 = b"\xaa\xbb\xcc\xdd" * 20  # 80 bytes

        # Simulates the Hume TTS client's streaming response: base64-encoded
        # snippets sharing one generation id, the last one flagged as such.
        # Parameter names mirror the keyword arguments the extension passes.
        async def mock_tts_stream(
            context=None, utterances=None, format=None, instant_mode=None
        ):
            for audio, is_last in (
                (fake_audio_chunk_1, False),
                (fake_audio_chunk_2, True),
            ):
                snippet = MagicMock()
                snippet.generation_id = "test_gen_id"
                snippet.audio = base64.b64encode(audio).decode("utf-8")
                snippet.is_last_chunk = is_last
                yield snippet

        mock_client.tts.synthesize_json_streaming = mock_tts_stream

        # --- Test Setup ---
        tester = ExtensionTesterDump()

        dump_config = {
            "params": {
                "key": "test_api_key",
                "voice_name": "Female English Actor",
                "provider": "HUME_AI",
                "speed": 1.0,
                "trailing_silence": 0.0,
                "request_timeout_seconds": 10,
            },
            "dump": True,
            "dump_path": DUMP_PATH,
        }

        tester.set_test_mode_single(
            "humeai_tts_python", json.dumps(dump_config)
        )

        print("Running dump test...")
        tester.run()
        print("Dump test completed.")

        # --- Verification ---
        # 1. Verify audio end was received
        assert tester.audio_end_received, "Expected to receive tts_audio_end"
        assert (
            len(tester.received_audio_chunks) > 0
        ), "Expected to receive audio chunks"

        # 2. Write received audio chunks to test file for comparison
        tester.write_test_dump_file()

        # 3. Find the dump file created by the extension
        tts_dump_file = tester.find_tts_dump_file()
        assert (
            tts_dump_file is not None
        ), f"Expected to find a TTS dump file in {DUMP_PATH}"
        assert os.path.exists(
            tts_dump_file
        ), f"TTS dump file should exist: {tts_dump_file}"

        # 4. Compare the files
        print(
            f"Comparing test file {tester.test_dump_file_path} with TTS dump file {tts_dump_file}"
        )
        assert filecmp.cmp(
            tester.test_dump_file_path, tts_dump_file, shallow=False
        ), "Test dump file and TTS dump file should have the same content"

        print(
            f"✅ Dump functionality test passed: received {len(tester.received_audio_chunks)} audio chunks"
        )
        print(f"   Test file: {tester.test_dump_file_path}")
        print(f"   TTS dump file: {tts_dump_file}")
    finally:
        # --- Cleanup ---
        # Runs even when the runner raises or an assertion fails, so no
        # dump artifacts leak into later runs.
        shutil.rmtree(DUMP_PATH, ignore_errors=True)


# ================ test text_input_end logic ================
class ExtensionTesterTextInputEnd(ExtensionTester):
    """Checks that a request_id is rejected after its text_input_end was seen.

    Flow: send ``tts_request_1`` with ``text_input_end=True``. When its
    ``tts_audio_end`` arrives, send another input that reuses the same
    request_id. Then wait for an ``error`` message whose ``id`` is that
    request_id, record its code, message and module, and stop the test.
    """

    def __init__(self):
        super().__init__()
        self.ten_env: TenEnvTester | None = None
        self.first_request_audio_end_received = False
        self.second_request_error_received = False
        # Fields copied from the received "error" payload.
        self.error_code = None
        self.error_message = None
        self.error_module = None

    @staticmethod
    def _send_text_input(
        ten_env: TenEnvTester, tts_input: TTSTextInput
    ) -> None:
        """Wrap *tts_input* in a ``tts_text_input`` data message and send it."""
        message = Data.create("tts_text_input")
        message.set_property_from_json(None, tts_input.model_dump_json())
        ten_env.send_data(message)

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Remember the tester env and send the first (final) text input."""
        self.ten_env = ten_env_tester
        ten_env_tester.log_info(
            "TextInputEnd test started, sending first TTS request."
        )
        first_input = TTSTextInput(
            request_id="tts_request_1",
            text="hello world, hello agora",
            text_input_end=True,
        )
        self._send_text_input(ten_env_tester, first_input)
        ten_env_tester.on_start_done()

    def send_second_request(self):
        """Send a follow-up input that reuses the already-finished request_id."""
        env = self.ten_env
        if env is None:
            return

        env.log_info("Sending second TTS request, expecting an error.")
        # Reuses the request_id of the completed first request (this input
        # also sets text_input_end=True); the extension should reject it.
        second_input = TTSTextInput(
            request_id="tts_request_1",
            text="this should be ignored",
            text_input_end=True,
        )
        self._send_text_input(env, second_input)

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Trigger the second request on the first audio end; stop on the expected error."""
        name = data.get_name()
        ten_env.log_info(f"Received data: {name}")
        json_str, _ = data.get_property_to_json(None)

        if name == "tts_audio_end":
            payload = json.loads(json_str) if json_str else {}
            ten_env.log_info(f"Received tts_audio_end: {payload}")
            is_first_request = payload.get("request_id") == "tts_request_1"
            if is_first_request and not self.first_request_audio_end_received:
                ten_env.log_info(
                    "Received tts_audio_end for the first request."
                )
                self.first_request_audio_end_received = True
                # Only now send the request that is expected to fail.
                self.send_second_request()
            return

        if not json_str:
            return

        payload = json.loads(json_str)
        request_id = payload.get("id")
        if name != "error" or request_id != "tts_request_1":
            return

        ten_env.log_info(
            f"Received expected error for the second request: {payload}"
        )
        self.second_request_error_received = True
        self.error_code = payload.get("code")
        self.error_message = payload.get("message")
        self.error_module = payload.get("module")
        ten_env.stop_test()


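# NOTE(review): this test is disabled (commented out). Before re-enabling it,
# align it with the live tests in this file:
#   - EVENT_TTS_RESPONSE and EVENT_TTS_END are not imported here;
#     test_flush_logic uses TTS2HttpResponseEventType.RESPONSE / .END instead.
#   - The mocked `get` in test_flush_logic also receives `request_id`.
#   - The other configs nest their settings under "params".
# @patch("humeai_tts_python.extension.HumeAiTTS")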
# def test_text_input_end_logic(MockHumeAiTTS):
#     """
#     Tests that after a request with text_input_end=True is processed,
#     subsequent requests with the same request_id are ignored and trigger an error.
#     """
#     print("Starting test_text_input_end_logic with mock...")

#     # --- Mock Configuration ---
#     mock_instance = MockHumeAiTTS.return_value
#     mock_instance.cancel = AsyncMock()

#     async def mock_get_audio_stream(text: str):
#         yield (b"\x11\x22\x33", EVENT_TTS_RESPONSE)
#         yield (None, EVENT_TTS_END)

#     mock_instance.get.side_effect = mock_get_audio_stream

#     # --- Test Setup ---
#     config = {"key": "test_api_key", "voice_id": "daisy"}
#     tester = ExtensionTesterTextInputEnd()
#     tester.set_test_mode_single("humeai_tts_python", json.dumps(config))

#     print("Running text_input_end logic test...")
#     tester.run()
#     print("text_input_end logic test completed.")

#     # --- Assertions ---
#     assert (
#         tester.first_request_audio_end_received
#     ), "Did not receive tts_audio_end for the first request."
#     assert (
#         tester.second_request_error_received
#     ), "Did not receive the expected error for the second request."
#     assert (
#         tester.error_code == 1000
#     ), f"Expected error code 1000, but got {tester.error_code}"
#     assert (
#         tester.error_message is not None
#         and "Received a message for a finished request_id"
#         in tester.error_message
#     ), "Error message is not as expected."

#     print("✅ Text input end logic test passed successfully.")


# ================ test flush logic ================
class ExtensionTesterFlush(ExtensionTester):
    """Sends a long TTS request and flushes it as soon as audio starts flowing.

    Records which lifecycle events arrive (audio start/end, flush start/end),
    counts received audio bytes, and flags any audio frame that arrives after
    ``tts_flush_end``. The test is stopped 0.5 s after ``tts_flush_end``, so
    late frames have a chance to show up.
    """

    def __init__(self):
        super().__init__()
        self.ten_env: TenEnvTester | None = None
        # Lifecycle flags, set as the corresponding events arrive.
        self.audio_start_received = False
        self.first_audio_frame_received = False
        self.flush_start_received = False
        self.audio_end_received = False
        self.flush_end_received = False
        self.audio_received_after_flush_end = False
        # Values reported by the tts_audio_end event.
        self.audio_end_reason = ""
        self.total_audio_duration_from_event = 0
        # Audio accounting used to derive a duration from the raw bytes.
        self.received_audio_bytes = 0
        self.sample_rate = 48000  # Hume TTS sample rate
        self.bytes_per_sample = 2  # 16-bit
        self.channels = 1

    def on_start(self, ten_env_tester: TenEnvTester) -> None:
        """Send one long, final text input so that audio streams for a while."""
        self.ten_env = ten_env_tester
        ten_env_tester.log_info("Flush test started, sending long TTS request.")
        request = TTSTextInput(
            request_id="tts_request_for_flush",
            text="This is a very long text designed to generate a continuous stream of audio, providing enough time to send a flush command.",
            text_input_end=True,
        )
        message = Data.create("tts_text_input")
        message.set_property_from_json(None, request.model_dump_json())
        ten_env_tester.send_data(message)
        ten_env_tester.on_start_done()

    @staticmethod
    def _send_flush(ten_env: TenEnvTester) -> None:
        """Send a ``tts_flush`` targeting the request started in on_start."""
        message = Data.create("tts_flush")
        message.set_property_from_json(
            None,
            TTSFlush(flush_id="tts_request_for_flush").model_dump_json(),
        )
        ten_env.send_data(message)

    def on_audio_frame(self, ten_env: TenEnvTester, audio_frame):
        """Count audio bytes, flush on the first frame, flag frames after flush end."""
        if self.flush_end_received:
            ten_env.log_error("Received audio frame after tts_flush_end!")
            self.audio_received_after_flush_end = True

        if not self.first_audio_frame_received:
            self.first_audio_frame_received = True
            ten_env.log_info("First audio frame received, sending flush data.")
            self._send_flush(ten_env)

        frame_buf = audio_frame.lock_buf()
        try:
            self.received_audio_bytes += len(frame_buf)
        finally:
            audio_frame.unlock_buf(frame_buf)

    def _record_audio_end(self, payload: dict[str, Any]) -> None:
        """Store the reason and reported total duration from ``tts_audio_end``."""
        self.audio_end_received = True
        self.audio_end_reason = payload.get("reason")
        self.total_audio_duration_from_event = payload.get(
            "request_total_audio_duration_ms"
        )

    @staticmethod
    def _schedule_stop(ten_env: TenEnvTester, delay_s: float) -> None:
        """Stop the test from a timer thread after *delay_s* seconds."""

        def _stop():
            ten_env.log_info("Waited after flush_end, stopping test now.")
            ten_env.stop_test()

        threading.Timer(delay_s, _stop).start()

    def on_data(self, ten_env: TenEnvTester, data) -> None:
        """Dispatch lifecycle events; only ``tts_audio_start`` is accepted without a payload."""
        name = data.get_name()
        ten_env.log_info(f"on_data name: {name}")

        if name == "tts_audio_start":
            self.audio_start_received = True
            return

        json_str, _ = data.get_property_to_json(None)
        if not json_str:
            return
        payload = json.loads(json_str)
        ten_env.log_info(f"on_data payload: {payload}")

        if name == "tts_flush_start":
            self.flush_start_received = True
        elif name == "tts_audio_end":
            self._record_audio_end(payload)
        elif name == "tts_flush_end":
            self.flush_end_received = True
            # Delay stopping so any (erroneous) late audio frames are observed.
            self._schedule_stop(ten_env, 0.5)

    def get_calculated_audio_duration_ms(self) -> int:
        """Return the received audio length in whole milliseconds (truncated)."""
        bytes_per_second = (
            self.sample_rate * self.bytes_per_sample * self.channels
        )
        return int(self.received_audio_bytes / bytes_per_second * 1000)


@patch("humeai_tts_python.extension.HumeAiTTS")
def test_flush_logic(MockHumeAiTTS):
    """
    Tests that sending a flush command during TTS streaming correctly stops
    the audio and sends the appropriate events.

    The mocked client streams up to 20 audio chunks and switches to a FLUSH
    event once its ``cancel`` mock has been called. The duration reported in
    ``tts_audio_end`` must be present and agree (within 10 ms) with the
    duration computed from the audio bytes actually received.
    """
    print("Starting test_flush_logic with mock...")

    mock_instance = MockHumeAiTTS.return_value
    mock_instance.cancel = AsyncMock()
    mock_instance.clean = AsyncMock()

    async def mock_get_long_audio_stream(text: str, request_id: str):
        for _ in range(20):
            # In a real scenario, the cancel() call would set a flag.
            # We simulate this by checking the mock's 'called' status.
            if mock_instance.cancel.called:
                print("Mock detected cancel call, sending EVENT_TTS_FLUSH.")
                yield (None, TTS2HttpResponseEventType.FLUSH)
                return  # Stop the generator immediately after flush
            yield (b"\x11\x22\x33" * 100, TTS2HttpResponseEventType.RESPONSE)
            await asyncio.sleep(0.1)

        # This part is only reached if not cancelled - normal completion
        yield (None, TTS2HttpResponseEventType.END)

    mock_instance.get.side_effect = mock_get_long_audio_stream

    config = {
        "params": {
            "key": "test_api_key",
            "voice_name": "Female English Actor",
        },
    }
    tester = ExtensionTesterFlush()
    tester.set_test_mode_single("humeai_tts_python", json.dumps(config))

    print("Running flush logic test...")
    tester.run()
    print("Flush logic test completed.")

    assert tester.audio_start_received, "Did not receive tts_audio_start."
    assert tester.first_audio_frame_received, "Did not receive any audio frame."
    assert tester.audio_end_received, "Did not receive tts_audio_end."
    assert tester.flush_end_received, "Did not receive tts_flush_end."
    assert (
        not tester.audio_received_after_flush_end
    ), "Received audio after tts_flush_end."

    calculated_duration = tester.get_calculated_audio_duration_ms()
    event_duration = tester.total_audio_duration_from_event
    print(
        f"calculated_duration: {calculated_duration}, event_duration: {event_duration}"
    )
    # Fail with a clear message (rather than a TypeError in abs()) when the
    # event did not carry a numeric duration.
    assert isinstance(event_duration, (int, float)), (
        "tts_audio_end did not carry a numeric "
        f"request_total_audio_duration_ms (got {event_duration!r})."
    )
    assert (
        abs(calculated_duration - event_duration) < 10
    ), f"Mismatch in audio duration. Calculated: {calculated_duration}ms, From event: {event_duration}ms"

    print("✅ Flush logic test passed successfully.")
